#!/usr/bin/env python

import os
import json
import pickle
import sys
import traceback
import pandas as pd
import datetime
from pathlib import Path
import logging

import torch

from transformers import (
    BertTokenizer,
    XLNetTokenizer,
    RobertaTokenizer,
    DistilBertTokenizer,
)


from fast_bert.data_cls import BertDataBunch
from fast_bert.learner_cls import BertLearner
from fast_bert.metrics import (
    accuracy,
    accuracy_multilabel,
    accuracy_thresh,
    fbeta,
    roc_auc,
)

# Timestamp captured once at import time, formatted as YYYY-MM-DD_HH-MM-SS.
run_start_time = "{:%Y-%m-%d_%H-%M-%S}".format(datetime.datetime.today())

# This algorithm has a single channel of input data called 'training'.
channel_name = "training"

# Root of the container file layout; every path below is derived from it.
prefix = "/opt/ml/"

input_path = "{}input/data".format(prefix)  # /opt/ml/input/data
code_path = "{}code".format(prefix)  # /opt/ml/code

# /opt/ml/code/pretrained_models
pretrained_model_path = "{}/pretrained_models".format(code_path)

# /opt/ml/input/data/training/finetuned
finetuned_path = "/".join((input_path, channel_name, "finetuned"))

output_path = os.path.join(prefix, "output")  # /opt/ml/output
model_path = os.path.join(prefix, "model")  # /opt/ml/model

# /opt/ml/input/data/training/config
training_config_path = os.path.join(input_path, channel_name + "/config")

# /opt/ml/input/config/hyperparameters.json
hyperparam_path = os.path.join(prefix, "input/config/hyperparameters.json")

# /opt/ml/input/data/training/config/training_config.json
config_path = os.path.join(training_config_path, "training_config.json")

# Since we run in File mode, the input files of the 'training' channel are
# copied to this directory: /opt/ml/input/data/training
training_path = os.path.join(input_path, channel_name)


def searching_all_files(directory: Path):
    """Recursively list the contents of *directory*.

    Each regular file directly inside *directory* contributes its path as a
    string. Every other entry is recursed into, and the list produced for it
    is appended as a single nested element, so the result mirrors the
    directory tree rather than being a flat list.

    Args:
        directory: Directory to walk.

    Returns:
        A list whose items are either file-path strings or nested lists
        (one per sub-entry that is not a regular file), in the order
        yielded by ``Path.iterdir``.
    """
    return [
        str(entry) if entry.is_file() else searching_all_files(entry)
        for entry in directory.iterdir()
    ]


# Strings that should be read as "false" when they appear in a config file.
_FALSE_STRINGS = frozenset(("", "false", "f", "0", "no", "n", "off", "none", "null"))


def _as_bool(value):
    """Interpret a configuration value as a boolean.

    ``bool("False")`` is ``True`` in Python, so string-typed flags need
    explicit parsing. Strings are compared case-insensitively, ignoring
    surrounding whitespace, against a small set of "false" spellings; any
    other string is treated as true. Non-string values use normal truthiness.
    """
    if isinstance(value, str):
        return value.strip().lower() not in _FALSE_STRINGS
    return bool(value)


# The function to execute the training.
def train():
    """Run one training job end to end.

    Reads the training configuration from ``config_path`` and the
    hyperparameters from ``hyperparam_path``, builds the tokenizer, the
    ``BertDataBunch`` and the ``BertLearner``, fits the model, and then saves
    the model, a copy of the training configuration (``model_config.json``)
    and the label list (``labels.csv``) under ``model_path``.

    Any exception raised along the way is written to ``<output_path>/failure``
    and to stderr, after which the process exits with status 255.

    Raises:
        SystemExit: with code 255 when training fails.
    """
    print("Starting the training.")

    DATA_PATH = Path(training_path)
    LABEL_PATH = Path(training_path)

    try:
        print(config_path)
        with open(config_path, "r") as f:
            training_config = json.load(f)
            print(training_config)

        with open(hyperparam_path, "r") as tc:
            hyperparameters = json.load(tc)
            print(hyperparameters)

        # Logger: everything goes to stdout.
        logging.basicConfig(
            level=logging.INFO,
            format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
            datefmt="%m/%d/%Y %H:%M:%S",
            handlers=[logging.StreamHandler(sys.stdout)],
        )

        logger = logging.getLogger()

        # Parse the boolean flags once. They may be stored as strings such as
        # "False", for which a plain bool() would wrongly yield True.
        multi_label = _as_bool(training_config["multi_label"])
        do_lower_case = _as_bool(training_config["do_lower_case"])
        is_fp16 = _as_bool(training_config["fp16"])
        model_type = training_config["model_type"]

        # Define pretrained model path
        PRETRAINED_PATH = Path(pretrained_model_path) / training_config["model_name"]
        logger.info("model path used {}".format(PRETRAINED_PATH))

        finetuned_model_name = training_config.get("finetuned_model", None)
        if finetuned_model_name is not None:
            finetuned_model = os.path.join(finetuned_path, finetuned_model_name)
            logger.info("finetuned model loaded from {}".format(finetuned_model))
        else:
            logger.info(
                "finetuned model not available - loading standard pretrained model"
            )
            finetuned_model = None

        # Tokenizer: pick the class by model type and fail with a clear
        # message instead of leaving `tokenizer` unbound for unknown types.
        tokenizer_classes = {
            "bert": BertTokenizer,
            "distilbert": DistilBertTokenizer,
            "roberta": RobertaTokenizer,
            "xlnet": XLNetTokenizer,
        }
        if model_type not in tokenizer_classes:
            raise ValueError(
                "Unsupported model_type {!r}; expected one of {}".format(
                    model_type, sorted(tokenizer_classes)
                )
            )
        tokenizer = tokenizer_classes[model_type].from_pretrained(
            str(PRETRAINED_PATH), do_lower_case=do_lower_case
        )

        # Use the GPU when one is visible, otherwise fall back to the CPU.
        gpu_count = torch.cuda.device_count()
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        multi_gpu = gpu_count > 1

        logger.info("Number of GPUs: {}".format(gpu_count))

        # For multi-label runs the label columns may be given as a
        # JSON-encoded string; decode it, but accept an already-decoded value.
        label_col = training_config["label_col"]
        if multi_label and isinstance(label_col, str):
            label_col = json.loads(label_col)

        logger.info("label columns: {}".format(label_col))

        # Create databunch
        databunch = BertDataBunch(
            DATA_PATH,
            LABEL_PATH,
            tokenizer,
            train_file=training_config["train_file"],
            val_file=training_config["val_file"],
            label_file=training_config["label_file"],
            text_col=training_config["text_col"],
            label_col=label_col,
            batch_size_per_gpu=int(hyperparameters["train_batch_size"]),
            max_seq_length=int(hyperparameters["max_seq_length"]),
            multi_gpu=multi_gpu,
            multi_label=multi_label,
            model_type=model_type,
            logger=logger,
        )

        if multi_label:
            metrics = [
                {"name": "accuracy_thresh", "function": accuracy_thresh},
                {"name": "roc_auc", "function": roc_auc},
                {"name": "fbeta", "function": fbeta},
            ]
        else:
            metrics = [{"name": "accuracy", "function": accuracy}]

        logger.info("databunch labels: {}".format(len(databunch.labels)))
        logger.info(
            "multilabel: {}, multilabel type: {}".format(
                multi_label,
                type(multi_label),
            )
        )

        # Initialise the learner
        learner = BertLearner.from_pretrained_model(
            databunch,
            PRETRAINED_PATH,
            metrics=metrics,
            device=device,
            logger=logger,
            output_dir=Path(model_path),
            finetuned_wgts_path=finetuned_model,
            is_fp16=is_fp16,
            fp16_opt_level=training_config["fp16_opt_level"],
            warmup_steps=int(hyperparameters["warmup_steps"]),
            grad_accumulation_steps=int(training_config["grad_accumulation_steps"]),
            multi_gpu=multi_gpu,
            multi_label=multi_label,
            logging_steps=int(training_config["logging_steps"]),
        )

        learner.fit(
            int(hyperparameters["epochs"]),
            float(hyperparameters["lr"]),
            schedule_type=hyperparameters["lr_schedule"],
            optimizer_type=hyperparameters["optimizer_type"],
        )

        # save model and tokenizer artefacts
        learner.save_model()

        # save model config file
        with open(os.path.join(model_path, "model_config.json"), "w") as f:
            json.dump(training_config, f)

        # save label file
        with open(os.path.join(model_path, "labels.csv"), "w") as f:
            f.write("\n".join(databunch.labels))

    except Exception as e:
        # Write out an error file. This will be returned as the failureReason in the
        # DescribeTrainingJob result.
        trc = traceback.format_exc()
        with open(os.path.join(output_path, "failure"), "w") as s:
            s.write("Exception during training: " + str(e) + "\n" + trc)
        # Printing this causes the exception to be in the training job logs, as well.
        print("Exception during training: " + str(e) + "\n" + trc, file=sys.stderr)
        # A non-zero exit code causes the training job to be marked as Failed.
        sys.exit(255)


if __name__ == "__main__":
    train()

    # train() terminates the process with a non-zero status on failure, so
    # reaching this line means it completed. A zero exit code causes the job
    # to be marked as Succeeded.
    raise SystemExit(0)

